{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import copy as cp\n",
    "from collections import OrderedDict\n",
    "\n",
    "import torch\n",
    "from mmengine.runner.checkpoint import _load_checkpoint\n",
    "\n",
    "from mmaction.utils import register_all_modules\n",
    "from mmaction.registry import MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "# Backbone settings: one sub-config per pathway of RGBPoseConv3D.\n",
    "backbone_cfg = {\n",
    "    'type': 'RGBPoseConv3D',\n",
    "    'speed_ratio': 4,\n",
    "    'channel_ratio': 4,\n",
    "    'rgb_pathway': {\n",
    "        'num_stages': 4,\n",
    "        'lateral': True,\n",
    "        'lateral_infl': 1,\n",
    "        'lateral_activate': [0, 0, 1, 1],\n",
    "        'fusion_kernel': 7,\n",
    "        'base_channels': 64,\n",
    "        'conv1_kernel': (1, 7, 7),\n",
    "        'inflate': (0, 0, 1, 1),\n",
    "        'with_pool2': False,\n",
    "    },\n",
    "    'pose_pathway': {\n",
    "        'num_stages': 3,\n",
    "        'stage_blocks': (4, 6, 3),\n",
    "        'lateral': True,\n",
    "        'lateral_inv': True,\n",
    "        'lateral_infl': 16,\n",
    "        'lateral_activate': (0, 1, 1),\n",
    "        'fusion_kernel': 7,\n",
    "        'in_channels': 17,\n",
    "        'base_channels': 32,\n",
    "        'out_indices': (2, ),\n",
    "        'conv1_kernel': (1, 7, 7),\n",
    "        'conv1_stride_s': 1,\n",
    "        'conv1_stride_t': 1,\n",
    "        'pool1_stride_s': 1,\n",
    "        'pool1_stride_t': 1,\n",
    "        'inflate': (0, 1, 1),\n",
    "        'spatial_strides': (2, 2, 2),\n",
    "        'temporal_strides': (1, 1, 1),\n",
    "        'dilations': (1, 1, 1),\n",
    "        'with_pool2': False,\n",
    "    },\n",
    "}\n",
    "\n",
    "# Classification head; in_channels lists two feature widths, one per pathway.\n",
    "head_cfg = {\n",
    "    'type': 'RGBPoseHead',\n",
    "    'num_classes': 60,\n",
    "    'in_channels': [2048, 512],\n",
    "    'average_clips': 'prob',\n",
    "}\n",
    "\n",
    "model_cfg = {\n",
    "    'type': 'Recognizer3D',\n",
    "    'backbone': backbone_cfg,\n",
    "    'cls_head': head_cfg,\n",
    "}\n",
    "\n",
    "# Register the mmaction modules first so MODELS can resolve the type names.\n",
    "register_all_modules()\n",
    "model = MODELS.build(model_cfg)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# Locations (URL or local path) of the single-modality pretrained weights.\n",
    "# Point these at your own checkpoints if needed.\n",
    "rgb_filepath = (\n",
    "    'https://download.openmmlab.com/mmaction/v1.0/skeleton/posec3d/'\n",
    "    'rgbpose_conv3d/rgb_only_20230228-576b9f86.pth')\n",
    "pose_filepath = (\n",
    "    'https://download.openmmlab.com/mmaction/v1.0/skeleton/posec3d/'\n",
    "    'rgbpose_conv3d/pose_only_20230228-fa40054e.pth')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loads checkpoint by http backend from path: https://download.openmmlab.com/mmaction/v1.0/skeleton/posec3d/rgbpose_conv3d/rgb_only_20230226-8bd9d8df.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://download.openmmlab.com/mmaction/v1.0/skeleton/posec3d/rgbpose_conv3d/rgb_only_20230226-8bd9d8df.pth\" to C:\\Users\\wxDai/.cache\\torch\\hub\\checkpoints\\rgb_only_20230226-8bd9d8df.pth\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loads checkpoint by http backend from path: https://download.openmmlab.com/mmaction/v1.0/skeleton/posec3d/rgbpose_conv3d/pose_only_20230226-fa40054e.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://download.openmmlab.com/mmaction/v1.0/skeleton/posec3d/rgbpose_conv3d/pose_only_20230226-fa40054e.pth\" to C:\\Users\\wxDai/.cache\\torch\\hub\\checkpoints\\pose_only_20230226-fa40054e.pth\n"
     ]
    }
   ],
   "source": [
    "# Load both checkpoints onto the CPU and keep only their 'state_dict' entry.\n",
    "rgb_ckpt, pose_ckpt = (\n",
    "    _load_checkpoint(path, map_location='cpu')['state_dict']\n",
    "    for path in (rgb_filepath, pose_filepath))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "def rename_checkpoint_keys(state_dict, pathway, fc_name):\n",
    "    \"\"\"Return a copy of ``state_dict`` with keys renamed for the fused model.\n",
    "\n",
    "    Every ``'backbone'`` in a key becomes ``'backbone.<pathway>'`` and every\n",
    "    ``'fc_cls'`` becomes ``fc_name``. Keys that already contain\n",
    "    ``'backbone.<pathway>'`` are not prefixed a second time, so re-running\n",
    "    this cell leaves the keys unchanged instead of producing\n",
    "    ``'backbone.<pathway>.<pathway>'``.\n",
    "\n",
    "    Args:\n",
    "        state_dict (dict): Mapping from parameter name to value.\n",
    "        pathway (str): Sub-module name inserted after ``'backbone'``.\n",
    "        fc_name (str): Replacement for ``'fc_cls'``.\n",
    "\n",
    "    Returns:\n",
    "        dict: New mapping with renamed keys; the values are the same objects.\n",
    "    \"\"\"\n",
    "    prefix = 'backbone.' + pathway\n",
    "    renamed = {}\n",
    "    for key, value in state_dict.items():\n",
    "        # Skip the prefix step for keys renamed by a previous run of this cell.\n",
    "        if prefix not in key:\n",
    "            key = key.replace('backbone', prefix)\n",
    "        renamed[key.replace('fc_cls', fc_name)] = value\n",
    "    return renamed\n",
    "\n",
    "\n",
    "rgb_ckpt = rename_checkpoint_keys(rgb_ckpt, 'rgb_path', 'fc_rgb')\n",
    "pose_ckpt = rename_checkpoint_keys(pose_ckpt, 'pose_path', 'fc_pose')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "# Merge both state dicts into one; on a duplicate key the pose entry wins.\n",
    "old_ckpt = {**rgb_ckpt, **pose_ckpt}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "def padding(weight, new_shape):\n",
    "    \"\"\"Zero-pad ``weight`` along dim 1 up to ``new_shape``.\n",
    "\n",
    "    The original values occupy the leading ``weight.shape[1]`` slots of dim 1\n",
    "    and the remaining slots are zeros. The result is allocated with\n",
    "    ``weight.new_zeros``, so it shares the dtype and device of ``weight``.\n",
    "\n",
    "    Args:\n",
    "        weight (torch.Tensor): Tensor to enlarge (at least 2-D).\n",
    "        new_shape (Sequence[int]): Target shape. It may differ from\n",
    "            ``weight.shape`` only in dim 1, where it must not be smaller.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: A new tensor of shape ``new_shape``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``new_shape`` differs from ``weight.shape`` outside\n",
    "            dim 1 or shrinks dim 1.\n",
    "    \"\"\"\n",
    "    old_shape, new_shape = tuple(weight.shape), tuple(new_shape)\n",
    "    # Without this check a mismatch in another dim would be silently\n",
    "    # broadcast by the slice assignment below and yield wrong weights.\n",
    "    compatible = (\n",
    "        len(old_shape) >= 2 and len(new_shape) == len(old_shape)\n",
    "        and new_shape[0] == old_shape[0]\n",
    "        and new_shape[2:] == old_shape[2:]\n",
    "        and new_shape[1] >= old_shape[1])\n",
    "    if not compatible:\n",
    "        raise ValueError(\n",
    "            'cannot pad weight of shape {} to {}: only dim 1 may grow'.format(\n",
    "                old_shape, new_shape))\n",
    "    new_weight = weight.new_zeros(new_shape)\n",
    "    new_weight[:, :old_shape[1]] = weight\n",
    "    return new_weight"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# (parameter name, target shape) for every weight widened along dim 1.\n",
    "# NOTE(review): presumably these are the convs that receive extra input\n",
    "# channels from the lateral connections of the fused model - confirm.\n",
    "padded_shapes = (\n",
    "    ('backbone.rgb_path.layer3.0.conv1.conv.weight', (256, 640, 3, 1, 1)),\n",
    "    ('backbone.rgb_path.layer3.0.downsample.conv.weight', (1024, 640, 1, 1, 1)),\n",
    "    ('backbone.rgb_path.layer4.0.conv1.conv.weight', (512, 1280, 3, 1, 1)),\n",
    "    ('backbone.rgb_path.layer4.0.downsample.conv.weight', (2048, 1280, 1, 1, 1)),\n",
    "    ('backbone.pose_path.layer2.0.conv1.conv.weight', (64, 160, 3, 1, 1)),\n",
    "    ('backbone.pose_path.layer2.0.downsample.conv.weight', (256, 160, 1, 1, 1)),\n",
    "    ('backbone.pose_path.layer3.0.conv1.conv.weight', (128, 320, 3, 1, 1)),\n",
    "    ('backbone.pose_path.layer3.0.downsample.conv.weight', (512, 320, 1, 1, 1)),\n",
    ")\n",
    "\n",
    "# Work on a deep copy so old_ckpt keeps the unpadded weights.\n",
    "ckpt = cp.deepcopy(old_ckpt)\n",
    "for name, target_shape in padded_shapes:\n",
    "    ckpt[name] = padding(ckpt[name], target_shape)\n",
    "\n",
    "ckpt = OrderedDict(ckpt)\n",
    "torch.save({'state_dict': ckpt}, 'rgbpose_conv3d_init.pth')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "data": {
      "text/plain": "_IncompatibleKeys(missing_keys=['backbone.rgb_path.layer2_lateral.conv.weight', 'backbone.rgb_path.layer3_lateral.conv.weight', 'backbone.pose_path.layer1_lateral.conv.weight', 'backbone.pose_path.layer1_lateral.bn.weight', 'backbone.pose_path.layer1_lateral.bn.bias', 'backbone.pose_path.layer1_lateral.bn.running_mean', 'backbone.pose_path.layer1_lateral.bn.running_var', 'backbone.pose_path.layer2_lateral.conv.weight', 'backbone.pose_path.layer2_lateral.bn.weight', 'backbone.pose_path.layer2_lateral.bn.bias', 'backbone.pose_path.layer2_lateral.bn.running_mean', 'backbone.pose_path.layer2_lateral.bn.running_var'], unexpected_keys=[])"
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: load the merged weights into the freshly built model.\n",
    "# strict=False reports keys absent from the checkpoint as missing_keys instead\n",
    "# of raising; the recorded output lists only *_lateral layers as missing and\n",
    "# no unexpected keys.\n",
    "model.load_state_dict(ckpt, strict=False)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
